from solver import solver
import torch
from src.model import SODModel

# Train the SOD model, then write its predictions to disk.

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Optional checkpoint to resume from (previously hard-coded and commented out as
# "./model_pth/29.pth"). Leave as None to train from freshly constructed weights.
RESUME_CHECKPOINT = None

SINet_model = SODModel()
if RESUME_CHECKPOINT is not None:
    # map_location remaps stored tensors onto the selected device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only machine.
    SINet_model.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location=device))
SI_solver = solver(SINet_model, device, epoch=30, lr=1e-4)

# Train. NOTE(review): "./models/" is presumably where solver.train saves
# checkpoints — confirm against solver.train.
SI_solver.train("./models/")

# Write the predicted images to the output/ directory.
SI_solver.test("./output/")